import pandas as p
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from AnalyzeHistoricalNHTS import getVehicleDataByYear, getPersonDataByYear, getHouseholdDataByYear, get_state_safety_inspection_df_by_year, get_state_safety_inspection_dict
from tqdm import tqdm

def plotText(y_high=17.5, y_low=5, no_safety_color='xkcd:red', safety_color='xkcd:blue', x=1):
    """Caption the two safety-inspection lines on the current matplotlib axes.

    Writes two left-aligned, coloured captions in data coordinates: the
    "not subject" caption at ``(x, y_high)`` and the "subject" caption at
    ``(x, y_low)``.

    Args:
        y_high: y position (data coordinates) of the "not subject" caption.
        y_low: y position (data coordinates) of the "subject" caption.
        no_safety_color: colour of the "not subject" caption.
        safety_color: colour of the "subject" caption.
        x: x position (data coordinates) shared by both captions.

    The colours are parameters, defaulting to the same values the ``__main__``
    section uses. They are not read from module globals, because those globals
    exist only when the file runs as a script. This lets the function also work
    when imported.
    """
    plt.text(x, y_high, 'Households not subject\nto safety inspections', color=no_safety_color, ha='left')
    plt.text(x, y_low, 'Households subject\nto safety inspections', color=safety_color, ha='left')

def _plot_by_inspection(data, x_col, y_col, xlabel, ylabel, title, out_path,
                        figsize, palette, xlim=None, ylim=None, labels=(),
                        labels_in_axes_coords=False):
    """Draw one weighted line plot split by safety-inspection status, save it, then show it.

    Lines are grouped by the 'Subject to Safety Inspection' column of ``data``
    and weighted by its 'WTHHFIN' column. ``palette`` gives the line colours,
    which seaborn assigns in sorted hue-level order.

    The seaborn legend is suppressed, so ``labels`` identifies the lines. It is
    an iterable of ``(x, y, text, color)`` tuples. Positions are in data
    coordinates by default. When ``labels_in_axes_coords`` is true, positions are
    axes fractions (0..1) with the text hanging down from ``y``, which keeps
    captions on the canvas whatever the data range.

    ``xlim``/``ylim`` are applied only when given. The figure is written to
    ``out_path`` with a tight bounding box and then displayed with ``plt.show()``.
    """
    fig, ax = plt.subplots(1, 1, figsize=figsize)
    sns.lineplot(data=data, x=x_col, y=y_col, hue='Subject to Safety Inspection',
                 weights='WTHHFIN', palette=palette, legend=False, ax=ax)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    if xlim is not None:
        ax.set_xlim(xlim)
    if ylim is not None:
        ax.set_ylim(ylim)
    if labels_in_axes_coords:
        transform, va = ax.transAxes, 'top'
    else:
        # Matplotlib's defaults for ax.text: data coordinates, baseline alignment.
        transform, va = ax.transData, 'baseline'
    for label_x, label_y, text, color in labels:
        ax.text(label_x, label_y, text, color=color, ha='left', va=va, transform=transform)
    ax.set_title(title)
    fig.savefig(out_path, bbox_inches='tight')
    plt.show()


def _plot_age_summaries(summaries, x_col, xlabel, x_title, x_file, figsize, palette,
                        xlim, ylim, no_safety_color, safety_color):
    """Plot minimum, mean and maximum household vehicle age against ``x_col``.

    ``summaries`` maps 'Minimum' / 'Mean' / 'Maximum' to per-household
    aggregated frames. ``x_title`` and ``x_file`` are the x-variable's wording in
    the plot title and in the output file name, respectively.
    """
    # (statistic name, file-name prefix, y of "not subject" caption, y of "subject" caption)
    specs = (
        ('Minimum', 'Min', 12, 3),
        ('Mean', 'Mean', 12, 5),
        ('Maximum', 'Max', 17.5, 5),
    )
    for stat, prefix, y_high, y_low in specs:
        labels = (
            (1, y_high, 'Households not subject\nto safety inspections', no_safety_color),
            (1, y_low, 'Households subject\nto safety inspections', safety_color),
        )
        _plot_by_inspection(
            summaries[stat], x_col, 'VEHAGE', xlabel,
            f'{stat} Household Vehicle Age',
            f'{stat} Vehicle Age vs. {x_title}\nBy Safety Inspection Requirement',
            f'Plots/{prefix}HouseholdVehicleAgeVs{x_file}.png',
            figsize, palette, xlim=xlim, ylim=ylim, labels=labels,
        )


if __name__ == '__main__':
    import os

    # Set to True to also draw the min/mean/max household vehicle-age plots.
    PLOT_AGE_SUMMARIES = False

    year = 2017
    # Vehicles whose model year is at least this are flagged as "New".
    # For the 2017 survey this is 2016.
    new_vehicle_min_year = year - 1

    vehData = getVehicleDataByYear(year)
    safety_inspection_df = get_state_safety_inspection_df_by_year(year)

    # Drop vehicles with a negative age (presumably NHTS missing-value codes -- TODO confirm),
    # then attach each household state's inspection flags.
    vehData = vehData.loc[vehData['VEHAGE'] >= 0, :]
    vehData = vehData.merge(safety_inspection_df, left_on='HHSTATE', right_on='State Code')

    # Explicit copy so the column assignments below never write into a view of vehData.
    vehDataSmall = vehData.loc[:, ['HOUSEID', 'VEHAGE', 'HasSafety', 'HHSIZE', 'WTHHFIN',
                                   'DRVRCNT', 'ANNMILES', 'VEHYEAR']].copy()
    vehDataSmall['New'] = (vehDataSmall['VEHYEAR'] >= new_vehicle_min_year).astype('int64')
    vehDataSmall['Subject to Safety Inspection'] = vehDataSmall['HasSafety']

    # Vehicle-level rows with a non-negative self-reported annual mileage.
    odomDataSmall = vehDataSmall.loc[vehDataSmall['ANNMILES'] >= 0, :].copy()

    if PLOT_AGE_SUMMARIES:
        # Per-household statistics, computed before 'Vehicle Count' is added.
        by_house = vehDataSmall.groupby('HOUSEID', as_index=False)
        age_summaries = {'Minimum': by_house.min(), 'Mean': by_house.mean(), 'Maximum': by_house.max()}

    # Household-level table. Summing a column of ones gives the vehicle count.
    # Dividing the summed columns below by that count turns them back into
    # per-household means: averaged values (VEHAGE) and presumably
    # household-constant ones (size, drivers, weight, inspection flag).
    vehDataSmall['Vehicle Count'] = 1
    vehData_count = vehDataSmall.groupby('HOUSEID', as_index=False).sum()
    cols_to_divide = ['WTHHFIN', 'DRVRCNT', 'Subject to Safety Inspection', 'VEHAGE', 'HasSafety', 'HHSIZE']
    vehData_count[cols_to_divide] = vehData_count[cols_to_divide].div(vehData_count['Vehicle Count'], axis=0)
    vehData_count['Share New'] = vehData_count['New'] / vehData_count['Vehicle Count']

    fig_width = 5
    golden_ratio = (1 + np.sqrt(5)) / 2
    figsize = (fig_width, fig_width / golden_ratio)
    noSafetyColor = 'xkcd:red'
    safetyColor = 'xkcd:blue'
    # Hue levels sort as (not subject, subject), so colours are listed in that order.
    palette = [noSafetyColor, safetyColor]
    x_lim = [0, 9]
    y_lim = [0, 24]

    # Captions in axes fractions, for plots whose y range is not known in advance.
    vehicle_labels_axes = (
        (0.03, 0.97, 'Vehicles not subject\nto safety inspections', noSafetyColor),
        (0.03, 0.80, 'Vehicles subject\nto safety inspections', safetyColor),
    )

    os.makedirs('Plots', exist_ok=True)

    # ---- With respect to the number of people in the household ----
    if PLOT_AGE_SUMMARIES:
        _plot_age_summaries(age_summaries, 'HHSIZE', 'Number of People In Household',
                            'Household Size', 'HouseholdSize', figsize, palette,
                            x_lim, y_lim, noSafetyColor, safetyColor)

    _plot_by_inspection(
        vehData_count, 'HHSIZE', 'Vehicle Count',
        'Number of People In Household', 'Number of Vehicles',
        'Number of Vehicles vs. Household Size\nBy Safety Inspection Requirement',
        'Plots/NumVehicleVsHouseholdSize.png', figsize, palette,
        xlim=x_lim, ylim=[0, 5],
        labels=((4, 4, 'Vehicles not subject\nto safety inspections', noSafetyColor),
                (3, 1, 'Vehicles subject\nto safety inspections', safetyColor)),
    )

    _plot_by_inspection(
        vehData_count, 'HHSIZE', 'Share New',
        'Number of People In Household', 'Proportion of New Vehicles',
        'Proportion of New Vehicles vs. Household Size\nBy Safety Inspection Requirement',
        'Plots/NewShareVsHouseholdSize.png', figsize, palette,
        xlim=x_lim, labels=vehicle_labels_axes, labels_in_axes_coords=True,
    )

    # ---- With respect to the number of drivers in the household ----
    if PLOT_AGE_SUMMARIES:
        _plot_age_summaries(age_summaries, 'DRVRCNT', 'Number of Drivers In Household',
                            'Household Drivers', 'HouseholdDrivers', figsize, palette,
                            x_lim, y_lim, noSafetyColor, safetyColor)

    _plot_by_inspection(
        vehData_count, 'DRVRCNT', 'Vehicle Count',
        'Number of Drivers In Household', 'Number of Vehicles',
        'Number of Vehicles vs. Household Drivers\nBy Safety Inspection Requirement',
        'Plots/NumVehicleVsHouseholdDrivers.png', figsize, palette,
        xlim=x_lim, labels=vehicle_labels_axes, labels_in_axes_coords=True,
    )

    _plot_by_inspection(
        vehData_count, 'DRVRCNT', 'Share New',
        'Number of Drivers In Household', 'Proportion of New Vehicles',
        'Proportion of New Vehicles vs. Household Drivers\nBy Safety Inspection Requirement',
        'Plots/NewShareVsHouseholdDrivers.png', figsize, palette,
        xlim=x_lim, labels=vehicle_labels_axes, labels_in_axes_coords=True,
    )

    # ---- Odometer outcomes (vehicle level) ----
    _plot_by_inspection(
        odomDataSmall, 'VEHAGE', 'ANNMILES',
        'Vehicle Age', 'Self-Reported Annual Miles',
        'Self-Reported Annual Miles vs. Vehicle Age\nBy Safety Inspection Requirement',
        'Plots/AnnualMilesVsVehicleAge.png', figsize, palette,
        labels=((1, 2500, 'Vehicles not subject\nto safety inspections', noSafetyColor),
                (22, 15000, 'Vehicles subject\nto safety inspections', safetyColor)),
    )